热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

鸡翅|时会_pytorch(网络模型)

篇首语:本文由编程笔记#小编为大家整理,主要介绍了pytorch(网络模型)相关的知识,希望对你有一定的参考价值。上一篇神经网络鸡翅nn.Mod

篇首语:本文由编程笔记#小编为大家整理,主要介绍了pytorch(网络模型)相关的知识,希望对你有一定的参考价值。


上一篇


神经网络鸡翅nn.Module


官网

import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))# 卷积、非线性处理
return F.relu(self.conv2(x))

练习

import torch
import torch.nn as nn
import torch.nn.functional as F
class Dun(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
return x+1
dun=Dun()
x=torch.tensor(1.0)# 转化类型
output=dun(x);# 调用forward
print(output)# 输出

卷积层


import torch
input=torch.tensor([[1,2,0,3,1],
[0,1,2,3,1],
[1,2,1,0,0],
[5,2,3,1,1],
[2,1,0,1,1]])
kernel=torch.tensor([[1,2,1],
[0,1,0],
[2,1,0]])
print(input.shape)# 输出尺寸
print(kernel.shape)
input=torch.reshape(input,(1,1,5,5))# 类型转换
kernel=torch.reshape(kernel,(1,1,3,3))# 类型转换
print(input)
print(kernel)
print(input.shape)
print(kernel.shape)
# 卷积操作
out= F.conv2d(input,kernel,stride=1)
print(out)
out= F.conv2d(input,kernel,stride=2)
print(out)
# 填充
out= F.conv2d(input,kernel,stride=1,padding=1)
print(out)

输出chanel是2时

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
dataset=torchvision.datasets.CIFAR10("./data_set_test",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader= DataLoader(dataset,batch_size=64)
# 卷积类
class Dun(nn.Module):
def __init__(self):
super().__init__()
self.conv1=Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)
def forward(self,x):
return self.conv1(x)
dun=Dun()
print(dun)
writer= SummaryWriter("./logs")
step=0
# 卷积操作
for data in dataloader:
img,target=data
output=dun(img)
print(img.shape)
print(output.shape)
writer.add_images("input",img,step)
output=torch.reshape(output,(-1,3,30,30))# -1时会根据后面的值自动计算
writer.add_images("output",output,step)
step+=1
writer.close()


池化层

作用:就像高清视频换成低清视频


import torch
from torch import nn
from torch.nn import MaxPool2d
input =torch.tensor([[1,2,0,3,1],
[0,1,2,3,1],
[1,2,1,0,0],
[5,2,3,1,1],
[2,1,0,1,1]],dtype=torch.float32)
input=torch.reshape(input,(-1,1,5,5))
class Dun(nn.Module):
def __init__(self):
super().__init__()
self.maxpool=MaxPool2d(kernel_size=3,ceil_mode=True)# ceil_model false和True的结果和预期的一致
def forward(self,inut):
return self.maxpool(input)
dun=Dun()
out=dun(input)
print(out)


图片处理

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
dataset=torchvision.datasets.CIFAR10("./data_set_test",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader= DataLoader(dataset,batch_size=64)
class Dun(nn.Module):
def __init__(self):
super().__init__()
self.maxpool=MaxPool2d(kernel_size=3,ceil_mode=True)# ceil_model false和True的结果和预期的一致
def forward(self,input):
return self.maxpool(input)
dun=Dun()
step=0
writer=SummaryWriter("logs")
for data in dataloader:
img,target=data
writer.add_images("input",img,step)
output=dun(img)
writer.add_images("output",output,step)
step+=1
writer.close()


非线性激活

非线性变换目的是引入非线性特征,可以更好地处理信息


ReLU

import torch
from torch import nn
from torch.nn import ReLU
input= torch.tensor([[1,-0.5],[-1,3]])
input=torch.reshape(input,(-1,1,2,2))
print(input.shape)
class Dun(nn.Module):
def __init__(self):
super().__init__()
self.relu1=ReLU()
def forward(self,input):
return self.relu1(input)
dun=Dun()
output=dun(input)
print(output)


sigmoid

import torch
import torchvision
from torch import nn
from torch.nn import ReLU, Sigmoid
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
dataset=torchvision.datasets.CIFAR10("./data_set_test",train=False,download=True,transform=torchvision.transforms.ToTensor())
dataloader=DataLoader(dataset,batch_size=64)
class Dun(nn.Module):
def __init__(self):
super().__init__()
self.relu1=ReLU()
self.sigmoid=Sigmoid()
def forward(self,input):
return self.sigmoid(input)
dun=Dun()
writer=SummaryWriter("./logs")
step=0
for data in dataloader:
img,target=data
writer.add_images("input",img,global_step=step)
output=dun(img)
writer.add_images("output",output,global_step=step)
step+=1
writer.close()


线性层


import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader
dataset=torchvision.datasets.CIFAR10("./data_set_test",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader=DataLoader(dataset,batch_size=64)
class Dun(nn.Module):
def __init__(self):
super().__init__()
self.linear=Linear(196608,10)
def forward(self,input):
return self.linear(input)
dun=Dun()
for data in dataloader:
img,target=data
print(img.shape)
# input=torch.reshape(img,(1,1,1,-1))
input= torch.flatten(img)# 将数据展平一行,可以代替上面的一行
print(input.shape)
output=dun(input)
print(output.shape)


正则化层

加快神经网络地训练速度

# With Learnable Parameters
m = nn.BatchNorm2d(100)
# Without Learnable Parameters
m = nn.BatchNorm2d(100, affine=False)
input = torch.randn(20, 100, 35, 45)
output = m(input)

其他层有Recurrent Layers、Transformer Layers、Linear Layers等


简单的网络模型

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.tensorboard import SummaryWriter
class Dun(nn.Module):
def __init__(self):
super().__init__()
# 2.
self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10))
# 1.
# self.conv1=Conv2d(3,32,5,padding=2)
# self.maxpool1=MaxPool2d(2)
# self.conv2=Conv2d(32,32,5,padding=2)
# self.maxpool2=MaxPool2d(2)
# self.conv3=Conv2d(32,64,5,padding=2)
# self.maxpool3=MaxPool2d(2)
# self.flatten=Flatten()
# self.linear1=Linear(1024,64)
# self.linear2=Linear(64,10)
def forward(self,x):
x=self.model1(x)
return x
dun=Dun()
# 测试
input=torch.ones((64,3,32,32))
print(dun(input).shape)
writer=SummaryWriter("./logs")
writer.add_graph(dun,input)
writer.close()


loss function


L1Loss、MSELoss

import torch
from torch.nn import L1Loss
from torch import nn
input=torch.tensor([1,2,3],dtype=torch.float32)
targrt=torch.tensor([1,2,5],dtype=torch.float32)
loss=L1Loss(reduction="sum")# 该参数有sum和mean两种,默认是mean
print(loss(input,targrt))
loss_mse=nn.MSELoss()
print(loss_mse(input,targrt))


CROSSENTROPYLOSS

import torch
from torch.nn import L1Loss
from torch import nn
x=torch.tensor([0.1,0.2,0.3])
y=torch.tensor([1])
x=torch.reshape(x,(1,3))
loss_cross=nn.CrossEntropyLoss()
print(loss_cross(x,y))


使用

import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
dataset=torchvision.datasets.CIFAR10("./data_set_test",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader=DataLoader(dataset=dataset,batch_size=1)
# 分类神经网络
class Dun(nn.Module):
def __init__(self):
super().__init__()
# 2.

推荐阅读
  • 本文详细介绍了Akka中的BackoffSupervisor机制,探讨其在处理持久化失败和Actor重启时的应用。通过具体示例,展示了如何配置和使用BackoffSupervisor以实现更细粒度的异常处理。 ... [详细]
  • 本文将介绍如何编写一些有趣的VBScript脚本,这些脚本可以在朋友之间进行无害的恶作剧。通过简单的代码示例,帮助您了解VBScript的基本语法和功能。 ... [详细]
  • 本文详细介绍了如何在Linux系统上安装和配置Smokeping,以实现对网络链路质量的实时监控。通过详细的步骤和必要的依赖包安装,确保用户能够顺利完成部署并优化其网络性能监控。 ... [详细]
  • 1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ... [详细]
  • 深入理解 SQL 视图、存储过程与事务
    本文详细介绍了SQL中的视图、存储过程和事务的概念及应用。视图为用户提供了一种灵活的数据查询方式,存储过程则封装了复杂的SQL逻辑,而事务确保了数据库操作的完整性和一致性。 ... [详细]
  • 深入解析Spring Cloud Ribbon负载均衡机制
    本文详细介绍了Spring Cloud中的Ribbon组件如何实现服务调用的负载均衡。通过分析其工作原理、源码结构及配置方式,帮助读者理解Ribbon在分布式系统中的重要作用。 ... [详细]
  • 在前两篇文章中,我们探讨了 ControllerDescriptor 和 ActionDescriptor 这两个描述对象,分别对应控制器和操作方法。本文将基于 MVC3 源码进一步分析 ParameterDescriptor,即用于描述 Action 方法参数的对象,并详细介绍其工作原理。 ... [详细]
  • Android 渐变圆环加载控件实现
    本文介绍了如何在 Android 中创建一个自定义的渐变圆环加载控件,该控件已在多个知名应用中使用。我们将详细探讨其工作原理和实现方法。 ... [详细]
  • golang常用库:配置文件解析库/管理工具viper使用
    golang常用库:配置文件解析库管理工具-viper使用-一、viper简介viper配置管理解析库,是由大神SteveFrancia开发,他在google领导着golang的 ... [详细]
  • Explore how Matterverse is redefining the metaverse experience, creating immersive and meaningful virtual environments that foster genuine connections and economic opportunities. ... [详细]
  • Explore a common issue encountered when implementing an OAuth 1.0a API, specifically the inability to encode null objects and how to resolve it. ... [详细]
  • 本文介绍如何使用Objective-C结合dispatch库进行并发编程,以提高素数计数任务的效率。通过对比纯C代码与引入并发机制后的代码,展示dispatch库的强大功能。 ... [详细]
  • 本文深入探讨 MyBatis 中动态 SQL 的使用方法,包括 if/where、trim 自定义字符串截取规则、choose 分支选择、封装查询和修改条件的 where/set 标签、批量处理的 foreach 标签以及内置参数和 bind 的用法。 ... [详细]
  • 本文详细介绍了Java中org.eclipse.ui.forms.widgets.ExpandableComposite类的addExpansionListener()方法,并提供了多个实际代码示例,帮助开发者更好地理解和使用该方法。这些示例来源于多个知名开源项目,具有很高的参考价值。 ... [详细]
  • 本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ... [详细]
author-avatar
一直很哇塞
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有